"""Driver for discrete-gauge Wilson-loop simulations.

This module implements a single entry point, ``run_discrete_gauge``, which
performs a simulation for one set of sweep parameters (b, k, n0, L) and a
configuration.  It loads the flip-count data and, for each gauge group, that
group's kernel.  Flip counts are mapped to a fractal dimension D by a
logistic function, D is mapped to a pivot weight g(D) by a linear function
scaled by the coupling ``b``, and that weight is multiplied into the kernel
to obtain the gauge potential A.  The link variables U are obtained from A
via the matrix exponential, Wilson loops are measured, and a string tension
is fitted.  Results are appended to a CSV summary.
"""

from __future__ import annotations

import os
import re
from typing import Dict, List
import numpy as np

from .build_lattice import build_lattice
from .compute_Amu import logistic_D, linear_gD
from .compute_Umu import compute_U_from_A
from .measure_wilson import measure_wilson_loops
from .plot_results import fit_string_tension
from sim_utils import save_csv


def _extract_gauges_from_cfg(cfg: Dict) -> List[str]:
    """Determine which gauge groups to run based on the configuration.

    Priority:
      1. cfg["crossover_analysis"]["gauge_groups"] if provided
      2. keys of cfg["kernel_path_template"] if it is a dict
      3. infer from string template: if it contains ``{gauge}``, return
         default ['U1','SU2','SU3']; otherwise try to extract the gauge from
         the filename (e.g., kernel_SU2_L{L}.npy → ['SU2']).

    Parameters
    ----------
    cfg : dict
        Top‑level configuration dictionary.

    Returns
    -------
    list of str
        Ordered list of gauge group names to run.
    """
    ca = cfg.get("crossover_analysis", {}) or {}
    gg = ca.get("gauge_groups")
    if gg:
        return list(gg)
    kpt = cfg.get("kernel_path_template")
    if isinstance(kpt, dict):
        return list(kpt.keys())
    if isinstance(kpt, str):
        if "{gauge}" in kpt:
            return ["U1", "SU2", "SU3"]
        m = re.search(r"kernel_(U1|SU2|SU3)_", os.path.basename(kpt))
        return [m.group(1)] if m else ["U1"]
    return ["U1", "SU2", "SU3"]


def _resolve_kernel_path(cfg: Dict, gauge: str, L: int) -> str:
    """Resolve the kernel file path for a given gauge and lattice size.

    The kernel path template may be a dict mapping gauge names to templates
    or a single string containing ``{gauge}`` and/or ``{L}`` placeholders.
    Raises a runtime error if the resolved filename does not contain the
    expected gauge tag.
    """
    kp = cfg.get("kernel_path_template")
    path: str | None = None
    if isinstance(kp, dict):
        if gauge not in kp:
            raise KeyError(f"kernel_path_template missing gauge '{gauge}'")
        path = kp[gauge]
    elif isinstance(kp, str):
        path = kp
    if path is None:
        kpaths = cfg.get("kernel_paths", {})
        if not isinstance(kpaths, dict) or gauge not in kpaths:
            raise KeyError(
                "No usable kernel path template in config (expected kernel_path_template or kernel_paths)"
            )
        path = kpaths[gauge]
    # Format the template with gauge and lattice size
    path = path.format(gauge=gauge, L=L)
    # Ensure the filename includes the gauge tag to avoid mismatches
    base = os.path.basename(path)
    tag = f"kernel_{gauge}_"
    if tag not in base:
        raise RuntimeError(
            f"Gauge/path mismatch: gauge={gauge} but resolved '{base}'. "
            f"Check your config's kernel_path_template."
        )
    return path


def _resolve_flip_counts_path(cfg: Dict, L: int) -> str:
    """Resolve the flip‑counts file path for a given lattice size."""
    fc_cfg = cfg.get("flip_counts_path_template") or cfg.get("flip_counts_path")
    if not fc_cfg:
        raise KeyError(
            "missing flip_counts_path_template / flip_counts_path in config"
        )
    if isinstance(fc_cfg, dict):
        if "template" in fc_cfg:
            tpl = fc_cfg["template"]
        elif L in fc_cfg:
            tpl = fc_cfg[L]
        elif "default" in fc_cfg:
            tpl = fc_cfg["default"]
        else:
            tpl = next(iter(fc_cfg.values()))
        return tpl.format(L=L)
    if isinstance(fc_cfg, str):
        return fc_cfg.format(L=L)
    raise TypeError("flip_counts path config must be string or dict")


def _scale_kernel(
    gD_vals: np.ndarray,
    K: np.ndarray,
    *,
    gauge: str,
    L: int,
    N_links: int,
) -> np.ndarray:
    """Multiply each link's kernel entry (or kernel block) by that link's pivot weight.

    ``gD_vals`` holds one weight per link.  The weights are reshaped to
    ``(N_links, 1, ..., 1)`` so they broadcast along the first axis of ``K``
    regardless of how many trailing axes the kernel has.  A 0-d kernel is a
    single scalar applied to every link.

    Raises
    ------
    ValueError
        If the kernel's first dimension does not equal ``N_links``.
    """
    if K.ndim == 1 and K.size != N_links:
        raise ValueError(
            f"kernel for gauge {gauge} at L={L} has {K.size} entries, expected {N_links}"
        )
    if K.ndim > 1 and K.shape[0] != N_links:
        raise ValueError(
            f"kernel for gauge {gauge} at L={L} has shape {K.shape}, "
            f"expected leading dimension {N_links}"
        )
    # Trailing singleton axes: none for 0-d/1-d kernels (plain elementwise or
    # scalar product), K.ndim - 1 of them otherwise.  A fixed [:, None, None]
    # would mis-broadcast any kernel that is not exactly 3-D.
    weights = np.reshape(gD_vals, (N_links,) + (1,) * max(K.ndim - 1, 0))
    return weights * K


def run_discrete_gauge(
    *,
    b: float,
    k: float,
    n0: float,
    L: int,
    cfg: Dict,
    output_dir: str,
) -> None:
    """Run a discrete-gauge simulation for a single parameter set.

    For every gauge group selected by the configuration, the pivot weights
    derived from the flip counts scale the gauge's kernel to form A.  The
    link variables U are computed from A, Wilson loops are measured, and a
    string tension with its 95% interval is fitted.  One summary row per
    gauge is passed to ``save_csv`` for ``string_tension_summary.csv`` in
    ``output_dir``.

    Parameters
    ----------
    b : float
        Coupling constant g.  This scales the pivot weight g(D).
    k : float
        Logistic slope used to compute the fractal dimension D.
    n0 : float
        Logistic center used to compute the fractal dimension D.
    L : int
        Linear size of the lattice (number of sites along one dimension).
    cfg : dict
        Configuration dictionary.  Must contain kernel and flip-counts
        templates as described in the README.  Optional keys read here:
        ``pivot`` (with ``a`` default 1.0 and ``b`` default 0.0) and
        ``loop_sizes`` (default ``[]``).
    output_dir : str
        Directory where the summary CSV will be written (created if missing).

    Raises
    ------
    FileNotFoundError
        If the flip-counts or a kernel file does not exist.
    ValueError
        If the flip counts or a kernel do not match the lattice's link count.
    """
    os.makedirs(output_dir, exist_ok=True)
    # The lattice's length is taken as the number of links (expected 2*L^2).
    lattice = build_lattice(L)
    N_links = len(lattice)

    # Load flip counts: one value per link.
    fc_path = _resolve_flip_counts_path(cfg, L=L)
    if not os.path.exists(fc_path):
        raise FileNotFoundError(f"flip_counts file not found: {fc_path!r}")
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code; only load files from trusted sources.
    flip_counts = np.load(fc_path, allow_pickle=True)
    if flip_counts.size != N_links:
        raise ValueError(
            f"flip_counts for L={L} has {flip_counts.size} entries, expected {N_links}"
        )
    # Flatten so the per-link weights are 1-D even if the file was saved
    # with a multi-dimensional shape.
    flip_counts = np.ravel(flip_counts)

    # Pivot parameters from cfg (``or {}`` also tolerates an explicit null).
    pivot_cfg = cfg.get("pivot") or {}
    a_shape = float(pivot_cfg.get("a", 1.0))
    b_pivot = float(pivot_cfg.get("b", 0.0))
    # Sweep parameters
    k_log = float(k)
    n0_log = float(n0)
    g_scale = float(b)
    loop_sizes = cfg.get("loop_sizes", [])

    gauges = _extract_gauges_from_cfg(cfg)
    summary_rows: List[Dict] = []
    # Dimension and pivot weight do not depend on the gauge; compute them once.
    D_vals = logistic_D(flip_counts, k_log, n0_log)
    gD_vals = linear_gD(D_vals, a_shape, b_pivot) * g_scale
    for gauge in gauges:
        kernel_path = _resolve_kernel_path(cfg, gauge=gauge, L=L)
        if not os.path.exists(kernel_path):
            raise FileNotFoundError(f"kernel file not found: {kernel_path!r}")
        # NOTE(review): same allow_pickle caveat as above.
        K = np.load(kernel_path, allow_pickle=True)
        A = _scale_kernel(gD_vals, K, gauge=gauge, L=L, N_links=N_links)
        U = compute_U_from_A(A, gauge_group=gauge)
        loops = measure_wilson_loops(lattice, U, loop_sizes)
        sigma, ci95 = fit_string_tension(loops)
        summary_rows.append({
            "b": b,
            "k": k,
            "n0": n0,
            "L": L,
            "gauge_group": gauge,
            "sigma": float(sigma),
            "ci95": float(ci95),
        })

    # Rows are written only after every gauge succeeds.
    summary_path = os.path.join(output_dir, "string_tension_summary.csv")
    for row in summary_rows:
        save_csv(summary_path, row)